#include "system.h"
#include "util.h"
#include "type.h"
#include "network.h"
#include "act.h"
#include "train.h"
#include "control.h"
#ifndef NO_TK
#include "display.h"
#include "graph.h"
#endif
#ifdef WIN32
#include <float.h>
#define isnan _isnan
#endif

/* Head of the singly linked list of registered training algorithms.
   registerAlgorithm() pushes each new record onto the front of this list
   and getAlgorithm() walks it to look a record up by code. */
Algorithm AlgorithmTable = NULL;

/* Registers one training algorithm.
 *
 *   code          - mask value identifying the algorithm
 *   shortName     - name passed to registerType() and to the Tcl side
 *   longName      - descriptive name (printed when training starts)
 *   updateWeights - routine called to apply one weight update
 *
 * A new record holding copies of both names is pushed onto the front of
 * AlgorithmTable, and the algorithm is announced to the interpreter through
 * the ".registerAlgorithm" command.  If that command fails, fatalError() is
 * called with the interpreter's result string. */
void registerAlgorithm(mask code, char *shortName, char *longName,
                       void (*updateWeights)(flag)) {
  Algorithm entry;

  entry = (Algorithm) safeMalloc(sizeof(struct algorithm),
                                 "registerAlgorithm:A");
  registerType(shortName, code, ALGORITHM);

  entry->code = code;
  entry->shortName = copyString(shortName);
  entry->longName = copyString(longName);
  entry->updateWeights = updateWeights;

  /* Link the record in at the head of the global list. */
  entry->next = AlgorithmTable;
  AlgorithmTable = entry;

  if (eval(".registerAlgorithm %s \"%s\" %d", shortName, longName, code)) {
    fatalError(Tcl_GetStringResult(Interp));
  }
}
  
void registerAlgorithms(void) {
  registerAlgorithm(STEEPEST, "steepest", "Steepest Descent", 
		    steepestUpdateWeights);
  registerAlgorithm(MOMENTUM, "momentum", "Momentum Descent", 
		    momentumUpdateWeights);
  registerAlgorithm(DOUGS_MOMENTUM, "dougsMomentum", "Doug's Momentum",
		    dougsMomentumUpdateWeights);
#ifdef ADVANCED
  registerAlgorithm(DELTA_BAR_DELTA, "deltaBarDelta", "Delta-Bar-Delta", 
		    deltabardeltaUpdateWeights);
  /*
  registerAlgorithm(QUICK_PROP, "quickProp", "Quick-Prop", 
		    quickpropUpdateWeights);
  */
#endif /* ADVANCED */
}

/* Looks up the registered algorithm whose code equals `code`.
   Returns the matching record.  If no record matches, fatalError() is
   called; should it return, the result is a null Algorithm. */
Algorithm getAlgorithm(mask code) {
  Algorithm entry = AlgorithmTable;

  while (entry) {
    if (entry->code == code) return entry;
    entry = entry->next;
  }

  fatalError("unknown algorithm code: %d", code);
  return entry; /* null here: the list was exhausted */
}

/* Prints the column-heading line for the training report table, one
   heading per value emitted by printReport(). */
void printReportHeader(void) {
  print(1,
        "__Update__"
        "__Error__"
        "_UnitCost_"
        "_Wgt.Cost_"
        "_Grad.Lin_"
        "_TimeUsed_"
        "_TimeLeft__\n");
}

/* Prints one row of the training report: total updates, error, unit cost,
 * weight cost, gradient linearity, time used and estimated time left.
 *
 *   lastReport - update index of the previous report; 0 marks the first
 *                report of a run and restarts the report counter
 *   update     - index of the update just completed in this run
 *   startTime  - getTime() value taken when the run began
 *
 * Time values are multiplied by 1e-3 before being handed to printTime(),
 * so getTime() appears to count milliseconds.  The per-update time used for
 * the estimate is the run average for the first four reports and an equally
 * weighted blend of the previous estimate and the latest interval's rate
 * afterwards.  Nothing is printed (and displays are not refreshed) when
 * Verbosity is below 1.  Always returns TCL_OK. */
flag printReport(int lastReport, int update, unsigned long startTime) {
  static unsigned int reportNum;
  static unsigned long last;
  static real perUpdate = 0.0;
  unsigned long now;
  unsigned long remaining = 0;
  char buf[128];

  if (lastReport) reportNum++;
  else reportNum = 1;

  if (Verbosity < 1) return TCL_OK;

  /* Numeric columns, each followed by a two-space separator. */
  print(1, "%7d)  ", Net->totalUpdates);
  smartPrintReal(Net->error, 8, FALSE);
  print(1, "  ");
  smartPrintReal(Net->outputCost, 8, FALSE);
  print(1, "  ");
  smartPrintReal(Net->weightCost, 8, FALSE);
  print(1, "  ");
  smartPrintReal(Net->gradientLinearity, 8, FALSE);
  print(1, "  ");

  /* Time used so far. */
  now = getTime();
  printTime((now - startTime) * 1e-3, buf);
  print(1, " %s   ", buf);

  /* Time left: only estimated while updates remain. */
  if (update != 0 && Net->numUpdates > update) {
    if (reportNum >= 5) {
      perUpdate = 0.5 * perUpdate + 0.5 *
        ((real) (now - last) / (update - lastReport));
    } else {
      perUpdate = (real) (now - startTime) / update;
    }
    last = now;
    remaining = (unsigned long) (perUpdate * 1e-3 *
                                 (Net->numUpdates - update));
  }
  printTime(remaining, buf);
  print(1, "%s\n", buf);

#ifndef NO_TK
  updateDisplays(ON_REPORT);
#endif
  return TCL_OK;
}

/* Runs the standard training loop on the current network.
 * Assumes there is a net.
 *
 * The training parameters are validated first.  Each iteration then runs one
 * batch through Net->netTrainBatch, applies the selected algorithm's weight
 * update, and optionally prints a report.  The loop ends when:
 *   - Net->numUpdates updates have been performed, or
 *   - Net->minCritBatches is positive and the criterion (overall error below
 *     Net->criterion, or the group criterion flag set by the batch) has held
 *     for that many consecutive batches, or
 *   - the batch routine returns non-zero, or
 *   - smartUpdate() asks for a stop.
 *
 * Returns the value of warning()/result() when the request is rejected or
 * empty, TCL_ERROR when a batch returned TCL_ERROR or training stopped before
 * completing, and TCL_OK otherwise.  A textual summary is built with
 * result()/append(). */
flag standardNetTrain(void) {
  int i, updatesDone, lastReport, batchesAtCriterion;
  flag willReport, groupCritReached, value = TCL_OK, done;
  unsigned long startTime;
  Algorithm A;

  if (Net->numUpdates < 0)
    return warning("numUpdates (%d) must be positive.", Net->numUpdates);
  if (Net->numUpdates == 0) return result("");
  if (!Net->trainingSet)
    return warning("There is no training set.");
  if (Net->learningRate < 0.0)
    return warning("learningRate (%f) cannot be negative.", Net->learningRate);
  /* Reject anything outside [0,1).  The two tests must be OR-ed: no value
     can be both negative and >= 1, so an AND here never fires. */
  if (Net->momentum < 0.0 || Net->momentum >= 1.0)
    return warning("momentum (%f) is out of range [0,1).", Net->momentum);
  if (Net->weightDecay < 0.0 || Net->weightDecay > 1.0)
    return warning("weightDecay (%f) must be in the range [0,1].",
                   Net->weightDecay);
  if (Net->reportInterval < 0)
    return warning("reportInterval (%d) cannot be negative.",
                   Net->reportInterval);

  A = getAlgorithm(Net->algorithm);

  print(1, "Performing %d updates using %s...\n",
        Net->numUpdates, A->longName);
  if (Net->reportInterval) printReportHeader();

  startTime = getTime();
  lastReport = batchesAtCriterion = 0;
  updatesDone = 0;
  groupCritReached = FALSE;
  done = FALSE;
  /* It always does at least one update. */
  for (i = 1; !done; i++) {
    RUN_PROC(preEpochProc);

    if ((value = Net->netTrainBatch(&groupCritReached))) break;

    /* Count consecutive batches that satisfy a stopping criterion. */
    if (Net->error < Net->criterion || groupCritReached)
      batchesAtCriterion++;
    else batchesAtCriterion = 0;
    if ((Net->minCritBatches > 0 && batchesAtCriterion >= Net->minCritBatches)
        || i >= Net->numUpdates) done = TRUE;

    /* Report every reportInterval updates and on the final update. */
    willReport = (Net->reportInterval &&
                  ((i % Net->reportInterval == 0) || done))
      ? TRUE : FALSE;

    RUN_PROC(postEpochProc);

    /* Here's the weight update (one epoch). */
    A->updateWeights(willReport);

    RUN_PROC(postUpdateProc);

#ifndef NO_TK
    updateDisplays(ON_UPDATE);
#endif

    Net->totalUpdates++;
    /* Counted separately from i: when the loop is left through the
       smartUpdate() break below, i has not yet been incremented, so i - 1
       would under-report the updates actually applied. */
    updatesDone++;

    if (willReport) {
      printReport(lastReport, i, startTime);
      lastReport = i;
    }
    /* Stop if requested. */
    if (smartUpdate(FALSE)) break;
    /* Change the algorithm if requested. */
    if (A->code != Net->algorithm) {
      A = getAlgorithm(Net->algorithm);
      print(1, "Changing algorithm to %s...\n", A->longName);
    }
  }
  /* From here on startTime holds the elapsed time, not the start time. */
  startTime = (getTime() - startTime);

#ifndef NO_TK
  updateDisplays(ON_TRAINING);
#endif

  if (value == TCL_ERROR) return TCL_ERROR;
  result("Performed %d updates\n", updatesDone);
  if (!done) {
    append("Training halted prematurely\n");
    value = TCL_ERROR;
  }
  /* Same strict comparison as the one that drives batchesAtCriterion. */
  if (Net->error < Net->criterion &&
      batchesAtCriterion >= Net->minCritBatches)
    append("Network reached overall error criterion of %f\n",
           Net->criterion);
  if (groupCritReached && batchesAtCriterion >= Net->minCritBatches)
    append("Network reached group output criterion\n");
  append("Total time elapsed: %.3f seconds", ((real) startTime * 1e-3));

  return value;
}

/* Adjusts the gain of every unit in group G.
   Each gain is reduced by Net->adaptiveGainRate times the unit's gainDeriv
   plus a decay term proportional to the current gain, and is then kept from
   falling below SMALL_VAL.  The decay coefficient is selected by
   chooseValue() from the group's and the network's gainDecay. */
void updateAdaptiveGain(Group G) {
  real rate = Net->adaptiveGainRate;
  real decay = chooseValue(G->gainDecay, Net->gainDecay);

  FOR_EACH_UNIT(G, {
    real newGain = U->gain - (rate * U->gainDeriv + decay * U->gain);
    if (newGain < SMALL_VAL) {
      newGain = SMALL_VAL;
    }
    U->gain = newGain;
  });
}


/* Steepest-descent weight update.
 *
 * The body below is handed to the UPDATE_WEIGHTS macro (defined elsewhere),
 * which supplies w, lastWeightDelta, learningRate, weightDecay and the
 * L / M / B handles used here.  Per the original author's note, the macro
 * expands the body separately for the doStats and no-doStats cases so the
 * no-stats inner loop needs no extra tests; doStats is not referenced
 * directly in this function.
 *
 * For each link the step is -learningRate * deriv, less weightDecay * weight
 * when weight decay is enabled.  The new weight is clamped to B->min / B->max
 * when those bounds are not NaN, and the step actually taken (after
 * clamping) is stored in M->lastWeightDelta. */
void steepestUpdateWeights(flag doStats) {
  UPDATE_WEIGHTS({
    w = L->weight;

    lastWeightDelta = -learningRate * L->deriv;
    if (weightDecay > 0.0) {
      lastWeightDelta -= weightDecay * w;
    }
    w += lastWeightDelta;

    /* A NaN bound means that side is unbounded. */
    if (!isNaN(B->min) && w < B->min) {
      w = B->min;
    } else if (!isNaN(B->max) && w > B->max) {
      w = B->max;
    }

    M->lastWeightDelta = w - L->weight;
    L->weight = w;
  });
}

/* Momentum-descent weight update: the steepest-descent step plus
 * momentum times the step previously recorded in M->lastWeightDelta.
 *
 * w, lastWeightDelta, learningRate, momentum, weightDecay and the L / M / B
 * handles come from the UPDATE_WEIGHTS macro (defined elsewhere); doStats is
 * not referenced directly in this function.  Weight decay, clamping to the
 * non-NaN B->min / B->max bounds, and recording of the step actually taken
 * are the same as in the steepest-descent rule. */
void momentumUpdateWeights(flag doStats) {
  UPDATE_WEIGHTS({
    w = L->weight;

    lastWeightDelta = -learningRate * L->deriv +
      momentum * M->lastWeightDelta;
    if (weightDecay > 0.0) {
      lastWeightDelta -= weightDecay * w;
    }
    w += lastWeightDelta;

    /* A NaN bound means that side is unbounded. */
    if (!isNaN(B->min) && w < B->min) {
      w = B->min;
    } else if (!isNaN(B->max) && w > B->max) {
      w = B->max;
    }

    M->lastWeightDelta = w - L->weight;
    L->weight = w;
  });
}

/* Like the momentum rule, except that every derivative is first multiplied
 * by a common scale factor.
 *
 * A first pass sums the squared derivatives of all links, skipping frozen
 * groups, units and blocks.  If that sum exceeds 1 the scale is
 * 1 / sqrt(sum), i.e. the derivative vector is shrunk to unit length;
 * otherwise the scale is 1 and the derivatives are used unchanged.  So the
 * pre-momentum step is bounded in length rather than fixed.
 *
 * The second pass is the usual UPDATE_WEIGHTS body (macro defined
 * elsewhere, supplying w, lastWeightDelta, learningRate, momentum,
 * weightDecay and L / M / B); doStats is not referenced directly here.
 *
 * Note: declarations inside the macro bodies are written one per statement
 * because a comma outside parentheses would split the macro argument. */
void dougsMomentumUpdateWeights(flag doStats) {
  double scale;
  double sum = 0.0;

  /* Pass 1: squared length of the derivative vector over unfrozen links. */
  FOR_EACH_GROUP({
    if (G->type & FROZEN) continue;
    FOR_EACH_UNIT(G, {
      Link L;
      Link sL;
      if (U->type & FROZEN) continue;
      L = U->incoming;
      FOR_EACH_BLOCK(U, {
        if (B->type & FROZEN) {
          /* Step over this block's links without counting them. */
          L += B->numUnits;
          continue;
        }
        sL = L + B->numUnits;
        while (L < sL) {
          sum += SQUARE(L->deriv);
          L++;
        }
      });
    });
  });

  scale = 1.0;
  if (sum > 1.0) {
    scale = 1.0 / SQRT(sum);
  }

  /* Pass 2: momentum update using the scaled derivative. */
  UPDATE_WEIGHTS({
    w = L->weight;

    lastWeightDelta = -learningRate * scale * L->deriv +
      momentum * M->lastWeightDelta;
    if (weightDecay > 0.0) {
      lastWeightDelta -= weightDecay * w;
    }
    w += lastWeightDelta;

    /* A NaN bound means that side is unbounded. */
    if (!isNaN(B->min) && w < B->min) {
      w = B->min;
    } else if (!isNaN(B->max) && w > B->max) {
      w = B->max;
    }

    M->lastWeightDelta = w - L->weight;
    L->weight = w;
  });
}


#ifdef ADVANCED
/* Delta-Bar-Delta keeps per-link state in two fields:
     M->lastValue       - the link's own learning-rate multiplier
     M->lastWeightDelta - the previous weight step (input to momentum) */

/* Non-zero when exactly one of x and y is negative. */
#define OPPOSITE_SIGN(x,y) (IS_NEGATIVE(x) ^ IS_NEGATIVE(y))

/* Delta-Bar-Delta weight update.
 *
 * For each link, the per-link rate grows by Net->rateIncrement when the
 * current derivative and the previous weight step have opposite signs, and
 * is multiplied by Net->rateDecrement otherwise.  The step is then
 * -linkRate * learningRate * deriv + momentum * previous step, with the same
 * weight decay, NaN-aware clamping and step recording as the other rules.
 *
 * w, deriv, lastWeightDelta, learningRate, momentum, weightDecay and the
 * L / M / B handles come from the UPDATE_WEIGHTS macro (defined elsewhere);
 * doStats is not referenced directly in this function. */
void deltabardeltaUpdateWeights(flag doStats) {
  real rateIncrement = Net->rateIncrement;
  real rateDecrement = Net->rateDecrement;
  real linkLearningRate;

  UPDATE_WEIGHTS({
    deriv = L->deriv;
    lastWeightDelta = M->lastWeightDelta;
    linkLearningRate = M->lastValue;

    /* Adapt this link's rate: additive increase, multiplicative decrease. */
    if (OPPOSITE_SIGN(deriv, lastWeightDelta)) {
      linkLearningRate += rateIncrement;
    } else {
      linkLearningRate *= rateDecrement;
    }

    w = L->weight;
    lastWeightDelta = -linkLearningRate * learningRate * deriv +
      momentum * lastWeightDelta;
    if (weightDecay > 0.0) {
      lastWeightDelta -= weightDecay * w;
    }
    w += lastWeightDelta;

    /* A NaN bound means that side is unbounded. */
    if (!isNaN(B->min) && w < B->min) {
      w = B->min;
    } else if (!isNaN(B->max) && w > B->max) {
      w = B->max;
    }

    M->lastWeightDelta = w - L->weight;
    L->weight = w;
    M->lastValue = linkLearningRate;
  });
}

#ifdef JUNK
/* Quick-Prop weight update.  Compiled only when both ADVANCED and JUNK are
   defined, and its registration in registerAlgorithms() is commented out.

   If lastWeightDelta is 0, I use steepest.  Is this best?
   lastValue stores the lastDeriv.

   Two bodies are passed to UPDATE_WEIGHTS here, whereas the other update
   rules in this file pass one.  The second body repeats the first and also
   accumulates gradLin, lastDeltaLen, derivLen and weightCost, so it is
   presumably the doStats variant for a two-body form of the macro.
   NOTE(review): confirm against the current UPDATE_WEIGHTS definition
   before enabling this code.

   NOTE(review): the quick-prop step divides by (M->lastValue - deriv) with
   no guard for the two derivatives being equal. */
void quickpropUpdateWeights(flag doStats) {
  UPDATE_WEIGHTS({
    /* Body 1: update only. */
    deriv = L->deriv;
    w = L->weight;
    lastWeightDelta = M->lastWeightDelta;
    if (lastWeightDelta == 0.0)
      /* No previous step: take a small gradient step instead. */
      lastWeightDelta = -learningRate * 0.1 * deriv;
    else /* I multiply by the learning rate here to stabilize it */
      lastWeightDelta = learningRate * 
	(deriv * lastWeightDelta) / (M->lastValue - deriv);
    /* Remember this derivative for the next update. */
    M->lastValue = deriv;
    if (weightDecay > 0.0) lastWeightDelta -= weightDecay * w;
    w += lastWeightDelta;
    /* Clamp to the block's bounds; a NaN bound is skipped. */
    if (!isNaN(B->min) && w < B->min) w = B->min;
    else if (!isNaN(B->max) && w > B->max) w = B->max;
    /* Record the step actually taken after clamping. */
    M->lastWeightDelta = w - L->weight;
    L->weight = w;
  }, {
    /* Body 2: same update, plus statistics accumulation. */
    deriv = L->deriv;
    w = L->weight;
    lastWeightDelta = M->lastWeightDelta;
    /* Statistics use the previous step, before it is overwritten below. */
    gradLin -= lastWeightDelta * deriv;
    lastDeltaLen += SQUARE(lastWeightDelta);
    derivLen += SQUARE(deriv);
    if (lastWeightDelta == 0.0)
      lastWeightDelta = -learningRate * 0.1 * deriv;
    else /* I multiply by the learning rate here to stabilize it */
      lastWeightDelta = learningRate * 
	(deriv * lastWeightDelta) / (M->lastValue - deriv);
    M->lastValue = deriv;
    if (weightDecay > 0.0) lastWeightDelta -= weightDecay * w;
    w += lastWeightDelta;
    if (!isNaN(B->min) && w < B->min) w = B->min;
    else if (!isNaN(B->max) && w > B->max) w = B->max;
    M->lastWeightDelta = w - L->weight;
    L->weight = w;
    /* Weight cost is accumulated from the updated weight. */
    weightCost += SQUARE(w);
  });
}
#endif /* JUNK */
#endif /* ADVANCED */
